# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Compute average precision, recall, and F1 for different thresholds.

The file takes as input:
- Gold jsonl file with gold document set.
- Output of running scoring with xattn model.
- Metadata generated by `gen_inference_inputs.py`.
"""

import collections
import math

from absl import app
from absl import flags
from language.quest.common import example_utils
from language.quest.common import jsonl_utils
from language.quest.common import tsv_utils
from language.quest.xattn import xattn_utils
import numpy as np

# Handle to the absl flag registry; the flag values below are read in `main`.
FLAGS = flags.FLAGS

# Path to the gold examples (jsonl).
flags.DEFINE_string(
    name="gold",
    default="",
    help="Examples jsonl file with gold predictions.")

# Path to the scored outputs of the xattn model.
flags.DEFINE_string(
    name="xattn_outputs",
    default="",
    help="Predictions from xattn model.")

# Path to the metadata tsv that accompanies the xattn outputs.
flags.DEFINE_string(
    name="metadata",
    default="",
    help="Input tsv file of metadata for xattn predictions.")


def _score_example(predicted_docs, gold_docs):
  """Returns (precision, recall, f1) of a predicted set against a gold set.

  Args:
    predicted_docs: Set of predicted document titles (may be empty).
    gold_docs: Non-empty set of gold document titles.

  Returns:
    Tuple of floats `(precision, recall, f1)`. All three are 0.0 when no
    predicted document appears in the gold set.
  """
  tp = len(gold_docs & predicted_docs)
  if not tp:
    return 0.0, 0.0, 0.0
  # tp + fp == |predicted| and tp + fn == |gold|.
  precision = tp / len(predicted_docs)
  recall = tp / len(gold_docs)
  f1 = 2 * precision * recall / (precision + recall)
  return precision, recall, f1


def main(unused_argv):
  """Prints avg. precision, recall and F1 over gold examples per threshold.

  Reads the files named by the `metadata`, `xattn_outputs` and `gold` flags.
  Metadata rows and xattn outputs are paired by position. For each threshold
  in [0.0, 1.0) with step 0.05, a (query, doc) pair is kept as a prediction
  when `exp(score)` exceeds the threshold, and the per-example metrics are
  averaged over gold examples that have at least one gold document.

  Args:
    unused_argv: Positional command-line arguments passed by absl; ignored.

  Raises:
    ValueError: If an xattn output's target label is not
      `xattn_utils.POS_LABEL`, or if no gold example has any documents.
  """
  # Local import so this block does not depend on a new file-level import.
  from absl import logging  # pylint: disable=g-import-not-at-top

  # Load xattn predictions and metadata.
  metadata = list(tsv_utils.read_tsv(FLAGS.metadata))
  xattn_outputs = list(jsonl_utils.read(FLAGS.xattn_outputs))
  gold_examples = example_utils.read_examples(FLAGS.gold)

  # Rows are paired purely by position, so a length mismatch means `zip` below
  # silently drops the unmatched tail. Surface it instead of hiding it.
  if len(metadata) != len(xattn_outputs):
    logging.warning(
        "Metadata has %d rows but xattn outputs has %d rows; "
        "unmatched trailing rows are ignored.",
        len(metadata), len(xattn_outputs))

  # Validate labels and convert scores once; neither depends on the threshold.
  scored_pairs = []
  for (query, doc_title), output in zip(metadata, xattn_outputs):
    label = output["inputs"]["targets_pretokenized"]
    if label != xattn_utils.POS_LABEL:
      # Report the value that was actually checked.
      raise ValueError("Unexpected label: %s" % label)
    # `score` is exponentiated before thresholding (presumably a log-prob).
    scored_pairs.append((query, doc_title, math.exp(output["score"])))

  # Gold doc sets are threshold-independent too. Examples without gold docs
  # are excluded from the averages.
  gold_sets = []
  for example in gold_examples:
    gold_docs = set(example.docs)
    if gold_docs:
      gold_sets.append((example.query, gold_docs))
  if not gold_sets:
    raise ValueError(
        "No gold examples with a non-empty document set; nothing to score.")
  num_examples = len(gold_sets)

  no_docs = frozenset()
  for threshold in np.arange(0.0, 1.0, 0.05):
    # Dictionary of query to docs predicted at this threshold.
    query_to_docs = collections.defaultdict(set)
    for query, doc_title, prob in scored_pairs:
      if prob > threshold:
        query_to_docs[query].add(doc_title)

    # Score every gold example against its predicted doc set.
    f1_list = []
    precision_list = []
    recall_list = []
    for query, gold_docs in gold_sets:
      # `.get` avoids inserting empty entries for queries with no predictions.
      predicted_docs = query_to_docs.get(query, no_docs)
      precision, recall, f1 = _score_example(predicted_docs, gold_docs)
      precision_list.append(precision)
      recall_list.append(recall)
      f1_list.append(f1)

    avg_f1 = sum(f1_list) / num_examples
    avg_precision = sum(precision_list) / num_examples
    avg_recall = sum(recall_list) / num_examples

    print("avg. f1 @ %.2f: %.3f" % (threshold, avg_f1))
    print("avg. precision @ %.2f: %.3f" % (threshold, avg_precision))
    print("avg. recall @ %.2f: %.3f" % (threshold, avg_recall))


# Script entry point: hand `main` to absl's app runner when executed directly.
if __name__ == "__main__":
  app.run(main)
